/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file flash_attention_score_grad_tiling_bngs1s2_b.cc
 * \brief
 */

#include "flash_attention_score_grad_tiling_common.h"
#include "tiling/tiling_base.h"
#include "tiling/tiling_type.h"
#include "tiling/tiling_templates_registry.h"
#include "flash_attention_score_grad_tiling_bngs1s2_b_def.h"

using namespace ge;
using namespace AscendC;

namespace optiling {

constexpr int64_t BSH_SBH_DIM_NUM = 3;
constexpr int64_t BNSD_DIM_NUM = 4;
constexpr uint32_t DIM_COUNT_NUM_4 = 4;
constexpr int64_t ATTEN_MASK_TYPE_11SS_DIM_NUM = 2;
constexpr int64_t WORK_SPACE_BASE_CAL = 16 * 1024 * 1024;
constexpr int64_t FP32_BYTES_NUM = 4;
constexpr int64_t FP16_BYTES_NUM = 2;
constexpr int64_t FP16_BLOCK_ELES = 16;
constexpr int64_t FP32_BLOCK_ELES = 8;
constexpr int64_t BLOCK_BYTE = 32;
constexpr int64_t PER_SUB_RANGE = 8;
constexpr int64_t CV_RATIO = 2;
constexpr int64_t MM_CV_RATIO = 1;
constexpr int64_t BEST_BASIC_BLOCK_SIZE = 64 * 128 * 4;
constexpr int64_t BEST_BASIC_BLOCK_NUM = 64 * 128;
constexpr int64_t L1_BYTE_SIZE = 512 * 1024;
constexpr int64_t CHECKNUM_FOR_L1_SIZE = 3;
constexpr uint32_t MAX_STRIDE_LIMIT = 65535;
constexpr uint32_t POST_COEX_NODE = 3;
constexpr uint32_t BASE_LEN_256 = 256;
constexpr uint32_t WORKSPACE_ALIGN_SIZE = 512;
constexpr uint32_t BUFFER_NUM = 1;

#define CHECK_ZERO(num)                       \
  do {                                        \
    if ((num) == 0) {                         \
      OPS_LOG_E("Number[%s] is zero.", #num); \
      return ge::GRAPH_FAILED;                \
    }                                         \
  } while (0)

template <typename... Args>
constexpr uint64_t GET_B_TILINGKEY(Args... templateIds) {
  return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}

struct TempParamsUngs1s2Bb {
  uint32_t dataType;
  uint32_t precisionMode;
  uint32_t layout;
  uint32_t attenMaskCompressMode;
  uint32_t attenMaskS1Size = 0;
  uint32_t attenMaskS2Size = 0;
  uint32_t b;
  uint32_t n;
  uint32_t d;
};

class FlashAttentionScoreGradUbngs1s2BbTiling : public TilingBaseClass {
 public:
  FlashAttentionScoreGradUbngs1s2BbTilingData td_;
  TempParamsUngs1s2Bb basicParams;

  explicit FlashAttentionScoreGradUbngs1s2BbTiling(gert::TilingContext* context) : TilingBaseClass(context){};

  ~FlashAttentionScoreGradUbngs1s2BbTiling() override = default;

  inline uint32_t Align(const uint32_t n, const uint32_t alignSize) {
    return (n + alignSize - 1) & (~(alignSize - 1));
  }

  uint64_t GetTilingKey() const override {
    uint64_t normalTilingKey = 0;
    uint64_t specMMTilingKey = 0;
    uint64_t tilingKey = 0;

    if (basicParams.dataType == static_cast<uint32_t>(ge::DT_FLOAT16)) {
      normalTilingKey = GET_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::FLOAT16_PRECISION,
                                      basicParams.layout, SparseEnum::NONE);
      specMMTilingKey = GET_B_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::FLOAT16_PRECISION,
                                        basicParams.layout, SparseEnum::NONE, MatmulConfig::NORMAL_CONFIG);
    } else if (basicParams.dataType == static_cast<uint32_t>(ge::DT_BF16)) {
      normalTilingKey = GET_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::BFLOAT16,
                                      basicParams.layout, SparseEnum::NONE);
      specMMTilingKey = GET_B_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::BFLOAT16,
                                        basicParams.layout, SparseEnum::NONE, MatmulConfig::NORMAL_CONFIG);
    } else {
      normalTilingKey = GET_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::FLOAT32,
                                      basicParams.layout, SparseEnum::NONE);
      specMMTilingKey = GET_B_TILINGKEY(AxisEnum::NONE, AxisEnum::NONE, AxisEnum::B, DtypeEnum::FLOAT32,
                                        basicParams.layout, SparseEnum::NONE, MatmulConfig::NORMAL_CONFIG);
    }

    tilingKey = normalTilingKey;
    if (basicParams.layout == static_cast<uint32_t>(InputLayout::SBH) &&
        basicParams.b * basicParams.n * basicParams.d > MAX_STRIDE_LIMIT) {
      // SBH: BND > 65535
      tilingKey = specMMTilingKey;
      return tilingKey;
    }
    // SBH: BND <= 65535
    return tilingKey;
  }

  bool IsCapable() override {
    if ((td_.opInfo.get_n() * td_.opInfo.get_g() * Align(td_.opInfo.get_sQ(), FP16_BLOCK_ELES) *
         Align(td_.opInfo.get_sKV(), FP16_BLOCK_ELES)) <= BEST_BASIC_BLOCK_NUM) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ubngs1s2Bb isCapable check true.");
      return true;
    }
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ubngs1s2Bb isCapable check false.");
    return false;
  }

  ge::graphStatus GetPlatformInfo() override {
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Get platform info begin.");

    auto platformInfoPtr = context_->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
      auto compileInfoPtr = reinterpret_cast<const FlashAttentionScoreGradCompileInfo*>(context_->GetCompileInfo());
      OPS_ERR_IF(compileInfoPtr == nullptr, OPS_REPORT_CUBE_INNER_ERR(context_, "compile_info is null"),
                 return ge::GRAPH_FAILED);

      aicoreParams_.blockDim = compileInfoPtr->aivNum;
      aicoreParams_.aicNum = compileInfoPtr->aicNum;
      aicoreParams_.ubSize = compileInfoPtr->ubSize;
      aicoreParams_.l1Size = compileInfoPtr->l1Size;
      aicoreParams_.l0aSize = compileInfoPtr->l0aSize;
      aicoreParams_.l0bSize = compileInfoPtr->l0bSize;
      aicoreParams_.l0cSize = compileInfoPtr->l0cSize;
    } else {
      auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
      aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
      aicoreParams_.aicNum = ascendcPlatform.GetCoreNumAic();
      ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, aicoreParams_.ubSize);
      ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, aicoreParams_.l1Size);
      ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, aicoreParams_.l0aSize);
      ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, aicoreParams_.l0bSize);
      ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, aicoreParams_.l0cSize);
    }

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetAttenMaskTilingInfo() {
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ubngs1s2Bb set attenmask tiling info begin.");
    basicParams.attenMaskCompressMode = NO_COMPRESS_MODE;
    auto attenMaskDesc = context_->GetOptionalInputDesc(ATTEN_MASK);
    if (attenMaskDesc != nullptr) {
      auto attenMaskDtype = attenMaskDesc->GetDataType();
      if (attenMaskDtype != ge::DT_BOOL && attenMaskDtype != ge::DT_UINT8) {
        OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "AttenMaskDtype(%u) is not bool or uint8.",
                  attenMaskDtype);
        return ge::GRAPH_PARAM_INVALID;
      }
    }
    auto attenMaskShape = context_->GetOptionalInputShape(ATTEN_MASK);
    if (attenMaskShape != nullptr && attenMaskShape->GetStorageShape().GetDimNum() != 0) {
      td_.opInfo.set_hasAttenMask(1);
      auto storageShape = attenMaskShape->GetStorageShape();
      auto maskShapeDims = storageShape.GetDimNum();
      if (maskShapeDims == ATTEN_MASK_TYPE_11SS_DIM_NUM) {
        td_.opInfo.set_attenMaskShapeType(ATTEN_MASK_SHAPE_TYPE_SS);
      } else if (maskShapeDims == DIM_COUNT_NUM_4) {
        auto dim0 = storageShape.GetDim(DIM_0);
        auto dim1 = storageShape.GetDim(DIM_1);
        if (dim0 == 1 && dim1 == 1) {
          td_.opInfo.set_attenMaskShapeType(ATTEN_MASK_SHAPE_TYPE_SS);
        } else if (dim0 == td_.opInfo.get_b() && dim1 == 1) {
          td_.opInfo.set_attenMaskShapeType(ATTEN_MASK_SHAPE_TYPE_B1SS);
        } else if (dim0 == td_.opInfo.get_b() && dim1 == td_.opInfo.get_n() * td_.opInfo.get_g()) {
          td_.opInfo.set_attenMaskShapeType(ATTEN_MASK_SHAPE_TYPE_BNSS);
        } else {
          OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "unsupport attenmask shape type.");
          return ge::GRAPH_PARAM_INVALID;
        }
        OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ubngs1s2Bb get attenmask shape dims success.");
      } else {
        OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ubngs1s2Bb set attenmask shape dims fail.");
        return ge::GRAPH_PARAM_INVALID;
      }
      basicParams.attenMaskS2Size = storageShape.GetDim(maskShapeDims - LAST_AXIS_IDX);
      basicParams.attenMaskS1Size = storageShape.GetDim(maskShapeDims - SEC_LAST_AXIS_IDX);
      int sparseMode = NO_MASK;
      if (context_->GetAttrs()->GetAttrNum() > static_cast<size_t>(SPARSE_MODE)) {
        sparseMode = *context_->GetAttrs()->GetAttrPointer<int>(SPARSE_MODE);  // 7
      }
      if (sparseMode == LEFT_UP_CAUSAL) {
        basicParams.attenMaskCompressMode = LEFT_UP_CAUSAL_MODE;
      } else if (sparseMode == RIGHT_DOWN_CAUSAL) {
        basicParams.attenMaskCompressMode = RIGHT_DOWN_CAUSAL_MODE;
      } else if (sparseMode > RIGHT_DOWN_CAUSAL) {
        return ge::GRAPH_PARAM_INVALID;
      }
    }
    td_.opInfo.set_attenMaskS2Size(basicParams.attenMaskS2Size);
    td_.opInfo.set_attenMaskCompressMode(basicParams.attenMaskCompressMode);
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb set attenmask tiling info success.");
    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetDataTypeInfo() {
    /* 3. 获取data type信息 */
    td_.opInfo.set_inputDType(static_cast<uint32_t>(context_->GetInputDesc(QUERY)->GetDataType()));
    basicParams.dataType = td_.opInfo.get_inputDType();
    int64_t inputDTypeSize =
        static_cast<int64_t>(GetSizeByDataType(static_cast<ge::DataType>(td_.opInfo.get_inputDType())));
    if (inputDTypeSize >= ge::kDataTypeSizeBitOffset) {
      return ge::GRAPH_PARAM_INVALID;
    }
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get inputDTypeSize success.");

    // bf16 场景下使用fp32(4 bytes)来进行vector的数据计算。
    uint32_t vecCalcDTypeSize = FP32_BYTES_NUM;

    td_.opInfo.set_vecCalcDTypeSize(vecCalcDTypeSize);
    td_.opInfo.set_inputDTypeSize(inputDTypeSize);
    int64_t sKVAlignSize = Align(td_.opInfo.get_sKV() * inputDTypeSize, BLOCK_BYTE);
    td_.opInfo.set_sKVAlignSize(sKVAlignSize);
    td_.opInfo.set_sKVAlign(sKVAlignSize / inputDTypeSize);
    td_.opInfo.set_sKVAlignByte(Align(td_.opInfo.get_sKV(), BLOCK_BYTE));
    td_.opInfo.set_originalDAlign(Align(td_.opInfo.get_d() * inputDTypeSize, BLOCK_BYTE) / inputDTypeSize);
    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetBaseAttrsInfo() {
    /* 1. 获取属性信息 */
    td_.opInfo.set_scaleValue(*context_->GetAttrs()->GetAttrPointer<float>(SCALE_VALUE));
    td_.opInfo.set_keepProb(*context_->GetAttrs()->GetAttrPointer<float>(KEEP_PROB));
    td_.opInfo.set_preTokens(*context_->GetAttrs()->GetAttrPointer<uint32_t>(PRE_TOKENS));
    td_.opInfo.set_nextTokens(*context_->GetAttrs()->GetAttrPointer<uint32_t>(NEXT_TOKENS));
    td_.opInfo.set_headNum(*context_->GetAttrs()->GetAttrPointer<uint32_t>(HEAD_NUM));
    const char* inputLayout = context_->GetAttrs()->GetAttrPointer<char>(INPUT_LAYOUT);
    if (strcmp(inputLayout, BSH_STR) == 0) {
      td_.opInfo.set_inputLayout(static_cast<uint32_t>(InputLayout::BSH));
    } else if (strcmp(inputLayout, SBH_STR) == 0) {
      td_.opInfo.set_inputLayout(static_cast<uint32_t>(InputLayout::SBH));
    } else if (strcmp(inputLayout, BNSD_STR) == 0) {
      td_.opInfo.set_inputLayout(static_cast<uint32_t>(InputLayout::BNSD));
    } else if (strcmp(inputLayout, BSND_STR) == 0) {
      td_.opInfo.set_inputLayout(static_cast<uint32_t>(InputLayout::BSND));
    } else {
      OPS_REPORT_VECTOR_INNER_ERR("FlashAttentionScoreGradUbngs1s2BbTiling", "INPUT_LAYOUT %s is not valid.",
                                  inputLayout);
      return ge::GRAPH_PARAM_INVALID;
    }
    td_.opInfo.set_precisionMode(HIGH_PRECISION);

    basicParams.precisionMode = td_.opInfo.get_precisionMode();
    basicParams.layout = td_.opInfo.get_inputLayout();
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get input layout success.");
    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus GetDimAttrsInfo() {
    /* 2. 获取shape和轴信息 */
    const gert::Shape& queryShape = context_->GetInputShape(QUERY)->GetStorageShape();
    const gert::Shape& keyShape = context_->GetInputShape(KEY)->GetStorageShape();
    if (td_.opInfo.get_inputLayout() == static_cast<uint32_t>(InputLayout::BNSD) ||
        td_.opInfo.get_inputLayout() == static_cast<uint32_t>(InputLayout::BSND)) {
      if (queryShape.GetDimNum() != BNSD_DIM_NUM || keyShape.GetDimNum() != BNSD_DIM_NUM) {
        OPS_REPORT_VECTOR_INNER_ERR("FlashAttentionScoreGradUbngs1s2BbTiling", "Shape size is not = 4[%zu, %zu].",
                                    queryShape.GetDimNum(), keyShape.GetDimNum());
        return ge::GRAPH_PARAM_INVALID;
      }
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get input dim success.");
      size_t layoutIdx = static_cast<size_t>(td_.opInfo.get_inputLayout());
      td_.opInfo.set_b(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][B]));
      td_.opInfo.set_sQ(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][S]));
      td_.opInfo.set_sKV(keyShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][S]));
      td_.opInfo.set_d(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][AXIS4_D]));
      td_.opInfo.set_hQ(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][AXIS4_N]) * td_.opInfo.get_d());
      td_.opInfo.set_hKV(keyShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][AXIS4_N]) * td_.opInfo.get_d());

      CHECK_ZERO(td_.opInfo.get_hKV());
      td_.opInfo.set_g(td_.opInfo.get_hQ() / td_.opInfo.get_hKV());
      CHECK_ZERO(td_.opInfo.get_g());
      td_.opInfo.set_n(td_.opInfo.get_headNum() / td_.opInfo.get_g());
    } else {
      if (queryShape.GetDimNum() != BSH_SBH_DIM_NUM || keyShape.GetDimNum() != BSH_SBH_DIM_NUM) {
        OPS_REPORT_VECTOR_INNER_ERR("FlashAttentionScoreGradUbngs1s2BbTiling", "Shape size is not = 3[%zu, %zu].",
                                    queryShape.GetDimNum(), keyShape.GetDimNum());
        return ge::GRAPH_PARAM_INVALID;
      }
      size_t layoutIdx = static_cast<size_t>(td_.opInfo.get_inputLayout());
      td_.opInfo.set_b(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][B]));
      td_.opInfo.set_sQ(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][S]));
      td_.opInfo.set_sKV(keyShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][S]));
      td_.opInfo.set_hQ(queryShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][H]));
      td_.opInfo.set_hKV(keyShape.GetDim(LAYOUT_TO_AXIS[layoutIdx][H]));

      CHECK_ZERO(td_.opInfo.get_hKV());
      td_.opInfo.set_g(td_.opInfo.get_hQ() / td_.opInfo.get_hKV());
      CHECK_ZERO(td_.opInfo.get_g());
      td_.opInfo.set_n(td_.opInfo.get_headNum() / td_.opInfo.get_g());
      CHECK_ZERO(td_.opInfo.get_n());
      td_.opInfo.set_d(td_.opInfo.get_hKV() / td_.opInfo.get_n());
    }
    basicParams.b = td_.opInfo.get_b();
    basicParams.n = td_.opInfo.get_n() * td_.opInfo.get_g();
    basicParams.d = td_.opInfo.get_d();
    return CheckShapeValid(context_, basicParams.b, basicParams.n, td_.opInfo.get_sQ(), basicParams.d);
  }

  ge::graphStatus GetShapeAttrsInfo() override {
    /* 1. 获取属性信息 */
    if (SetBaseAttrsInfo() != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb set base attrs info fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    /* 2. 获取shape和轴信息 */
    if (GetDimAttrsInfo() != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get dim attrs info fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    /* 3. 获取data type信息 */
    if (SetDataTypeInfo() != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb set data type fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    /* 4. 获取pse shape信息 */
    if (GetPseInfo() != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get pse info fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    if (SetAttenMaskTilingInfo() != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb set attenmask tiling fail.");
      return ge::GRAPH_PARAM_INVALID;
    }
    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus GetPseInfo() {
    /*
    Get pse input info
    */
    td_.opInfo.set_existPse(0);
    auto pseShape = context_->GetOptionalInputShape(PSE_SHIFT);
    if (pseShape != nullptr && pseShape->GetStorageShape().GetDimNum() != 0) {
      // pse support [B N S1 S2](0) + [B N 1 S2](1) + [1 N S1 S2](2)
      // 4.1 模板不涉及alibi，收益不大
      td_.opInfo.set_existPse(1);
      auto storageShape = pseShape->GetStorageShape();
      auto pseShapeDims = storageShape.GetDimNum();
      if (pseShapeDims == DIM_4) {
        auto dim0 = storageShape.GetDim(DIM_0);
        auto dim1 = storageShape.GetDim(DIM_1);
        auto dim2 = storageShape.GetDim(DIM_2);
        auto dim3 = storageShape.GetDim(DIM_3);
        td_.opInfo.set_pseSq(dim2);
        // [B N S1 S2](0)  [B N 1 S2](1)  [1 N S1 S2](2)shape判断
        uint32_t shapeN1 = td_.opInfo.get_n() * td_.opInfo.get_g();
        bool isBNS = (dim0 == td_.opInfo.get_b()) && (dim1 == shapeN1) && (dim3 == td_.opInfo.get_sKV());
        bool isBNSS = isBNS && (dim2 == td_.opInfo.get_sQ());
        bool isBN1S = isBNS && (dim2 == 1);
        bool is1NSS =
            (dim0 == 1) && (dim1 == shapeN1) && (dim2 == td_.opInfo.get_sQ()) && (dim3 == td_.opInfo.get_sKV());
        // 设置shape类型
        if (is1NSS) {
          td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_1NSS);
        } else if (isBN1S) {
          td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BN1S);
        } else if (isBNSS) {
          td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BNSS);
        } else {
          OPS_LOG_E("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb not support pseShape [%ld, %ld, %ld, %ld]",
                    dim0, dim1, dim2, dim3);
          return ge::GRAPH_FAILED;
        }
      } else {
        OPS_LOG_E("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb not support pseShape dim num: %zu",
                  pseShapeDims);
        return ge::GRAPH_FAILED;
      }
    }
    return ge::GRAPH_SUCCESS;
  }

  uint32_t GetApiTmpBufferSize(int64_t bIn, int64_t sQ, int64_t sKVAlign) {
    // softmax和dropout对应的vector计算的shape是一样的
    auto shape = ge::Shape({bIn * td_.opInfo.get_n() * td_.opInfo.get_g() * sQ, sKVAlign});

    uint32_t softmaxTmpSize = GetSoftMaxMinTmpSize(shape, td_.opInfo.get_vecCalcDTypeSize(), true);
    uint32_t dropoutTmpSize = GetDropOutMinTmpSize(shape, td_.opInfo.get_vecCalcDTypeSize(), true);

    return std::max(softmaxTmpSize, dropoutTmpSize);
  }

  bool CheckArgsLegal(int64_t bIn) {
    int64_t sQ = td_.opInfo.get_sQ();
    int64_t sKVAlign = td_.opInfo.get_sKVAlign();
    // 计算vecIn1和vecIn2的d轴需要限制，如果d轴大于128，这里仍然使用128计算，并做切分，否则
    // 在fp32场景下只能使用 32 * 32 的基本块；
    // 在fp16场景下只能用 64 * 64 的基本块
    // 在bf16场景下仍然可以用 64 * 128 的基本块
    // 由于d轴在softmaxgrad计算之后可以reduce掉，我们默认会对超过128的d轴做核内切分
    int64_t vecInQue1Size = BEST_BASIC_BLOCK_SIZE / 2;

    /* Bf16下使用fp32进行vector计算，这一步已经在vecCalcDTypeSize的计算中区分。 */
    int64_t vecClc1Size = BEST_BASIC_BLOCK_SIZE;
    int64_t softmaxGradQueSize = BEST_BASIC_BLOCK_SIZE / 2;
    int64_t dropoutQueSize = BEST_BASIC_BLOCK_SIZE / 4;
    int64_t maxSumQueSize = BEST_BASIC_BLOCK_SIZE;
    // bmm的TensorA + TensorB < 512KB (L1_Size)
    int64_t inputDTypeSize = td_.opInfo.get_inputDTypeSize();
    int64_t sQAlign4Input = Align(td_.opInfo.get_sQ() * inputDTypeSize, BLOCK_BYTE) / inputDTypeSize;
    int64_t sKVAlign4Input = td_.opInfo.get_sKVAlign();
    int64_t origDAlign4Input = td_.opInfo.get_originalDAlign();
    int64_t baseByteSize = td_.singleCoreParams.get_bCvInner() * td_.opInfo.get_n() * inputDTypeSize;

    // qDx(B, N2, G, S1_Align, D_Align)   kV(B, N2, 1, S2_Align, D_Align)   s1s2(B, N2, G, S2_Align, S2_Align)
    int64_t qDxByteSize = baseByteSize * td_.opInfo.get_g() * sQAlign4Input * origDAlign4Input;
    int64_t kVByteSize = baseByteSize * td_.opInfo.get_g() * sKVAlign4Input * origDAlign4Input;
    int64_t s1s2ByteSize = baseByteSize * td_.opInfo.get_g() * sQAlign4Input * sKVAlign4Input;
    std::array<int64_t, CHECKNUM_FOR_L1_SIZE> numbers = {qDxByteSize, kVByteSize, s1s2ByteSize};
    std::sort(numbers.begin(), numbers.end());
    int64_t largest = numbers[2];
    int64_t secondLargest = numbers[1];

    if ((largest + secondLargest) >= L1_BYTE_SIZE) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb bmm TensorA+B size:%ld out of range L1 size:%ld!",
                largest + secondLargest, L1_BYTE_SIZE);
      return false;
    }
    td_.singleCoreParams.set_innerTmpBufSize(vecClc1Size);
    td_.singleCoreParams.set_vecQueIn1Size(vecInQue1Size);
    int64_t queBufferSizeUb = vecInQue1Size * 2 + vecClc1Size * 2 + softmaxGradQueSize + maxSumQueSize + dropoutQueSize;
    // softmaxGrad，softmax，dropout计算所需要的tmpSize
    uint32_t maxTmpBufferSize = GetApiTmpBufferSize(bIn, sQ, sKVAlign);
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "softmaxGradComputeSize is:%u.", maxTmpBufferSize);
    uint32_t usedBufferSize = static_cast<uint32_t>(queBufferSizeUb) + maxTmpBufferSize + SOFTMAX_REMAIN_SIZE;

    int64_t ubSizeRemain = static_cast<int64_t>(aicoreParams_.ubSize) - usedBufferSize;
    if (ubSizeRemain >= 0) {
      td_.splitCoreParams.set_apiClcQueueSize(ubSizeRemain + API_RSDV_BUFFER_SIZE);
      return true;
    }
    return false;
  }

  ge::graphStatus DropoutTiling() {
    if (td_.opInfo.get_keepProb() < 1.0 && context_->GetOptionalInputShape(DROP_MASK) != nullptr &&
        context_->GetOptionalInputShape(DROP_MASK)->GetStorageShape().GetDimNum() != 0) {
      // 120KB FP16Tensor, 60KB U8Tensor, 8KB MaskTensor, 512B HelpTensor which less than UB(192KB).
      // singleUBProcessNum: UB最大处理FP16数据大小，需保证能被128整除.
      uint32_t castBufferLen = 60 * 1024;
      uint32_t outputBufferLen = 30 * 1024;
      uint32_t inputBufferLen = 4 * 1024;
      uint32_t singleUBProcessNum = castBufferLen / 2;

      uint32_t dropoutWorkspaceLen =
          td_.opInfo.get_b() * td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ() * td_.opInfo.get_sKV();
      dropoutWorkspaceLen = CeilCommon(dropoutWorkspaceLen, BLOCK_BYTE) * BLOCK_BYTE;
      td_.opInfo.set_dropoutWorkspaceLen(dropoutWorkspaceLen);
      td_.dropoutCastParams.set_castBufferLen(castBufferLen);
      td_.dropoutCastParams.set_outputBufferLen(outputBufferLen);
      td_.dropoutCastParams.set_inputBufferLen(inputBufferLen);
      td_.dropoutCastParams.set_singleUBProcessNum(singleUBProcessNum);
      return ge::GRAPH_SUCCESS;
    }
    td_.opInfo.set_dropoutWorkspaceLen(0);
    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus DoOpTiling() override {
    (void)DropoutTiling();

    auto ret = DoCoresSplitTiling();
    if (ret != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get core split tiling fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    ret = DoInCoreTiling();
    if (ret != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get single core tiling fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    ret = DoMulsTiling();
    if (ret != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb get muls tiling fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    ret = CheckAttenMaskShape();
    if (ret != ge::GRAPH_SUCCESS) {
      OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Ungs1s2Bb check atten mask shape fail.");
      return ge::GRAPH_PARAM_INVALID;
    }

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus CheckAttenMaskShape() {
    if (basicParams.attenMaskCompressMode == NO_COMPRESS_MODE) {
      bool invalid = td_.opInfo.get_hasAttenMask() != 0 &&
                     ((int64_t)basicParams.attenMaskS1Size * (int64_t)basicParams.attenMaskS2Size <
                      (int64_t)td_.opInfo.get_sQ() * (int64_t)td_.opInfo.get_sKV());
      if (invalid) {
        OPS_LOG_E("CheckAttenMaskShape", "atten mask shape [%u,%u] is invalid.", basicParams.attenMaskS1Size,
                  basicParams.attenMaskS2Size);
        return ge::GRAPH_PARAM_INVALID;
      }
    } else {
      if (basicParams.attenMaskS1Size != basicParams.attenMaskS2Size) {
        OPS_LOG_E("Atten Mask Compress", "atten mask shape is not square.");
        return ge::GRAPH_PARAM_INVALID;
      }
      if (basicParams.attenMaskS2Size < std::max(td_.opInfo.get_sQ(), td_.opInfo.get_sKV()) * MULT_BASE) {
        OPS_LOG_E("Atten Mask Compress", "atten mask shape is small, try setting it to [2048, 2048].");
        return ge::GRAPH_PARAM_INVALID;
      }
    }
    return ge::GRAPH_SUCCESS;
  }

  /* 切N轴分核的基础是，当前g * sQ * sKV比较小，且N比较大的时候，需要将N拆成nOut和nIn，其中
   * nIn参与到内层g * sQ * sKV的数据搬运和计算中，保证单次数据搬运和vector算力可以发挥到极致。
   * nOut参与到外层的分核，b和nOut一起分核。
   * 当前计算nOut和nIn的方式是让 nIn * g * sQ * sKV 尽可能的接近128 *
   * 128。然后将ubsize累加之后计算是否总的ub大小够放， 如果够的话直接使用当前nIn，否则使用64 * 128，64 *
   * 64的拼凑方式。 */
  ge::graphStatus DoCoresSplitTiling() {
    int64_t nGSqSkvAlign = td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ() * td_.opInfo.get_sKVAlign();
    if (nGSqSkvAlign == 0) {
      return ge::GRAPH_PARAM_INVALID;
    }
    int64_t bIn = BEST_BASIC_BLOCK_SIZE / (nGSqSkvAlign * td_.opInfo.get_vecCalcDTypeSize());
    int64_t b = td_.opInfo.get_b();
    if (bIn > b) {
      bIn = b;
    }
    td_.singleCoreParams.set_bIn(bIn);
    int64_t cvRatio = CV_RATIO;
    int64_t bCvRatioInner = bIn * cvRatio;
    if (bCvRatioInner > b) {
      bCvRatioInner = bIn;
      cvRatio = 1;
    }
    td_.singleCoreParams.set_bCvInner(bCvRatioInner);
    td_.singleCoreParams.set_bCvRatio(cvRatio);

    bool ret = CheckArgsLegal(bIn);
    if (!ret) {
      return ge::GRAPH_PARAM_INVALID;
    }
    int64_t inputDTypeSize = td_.opInfo.get_inputDTypeSize();
    int64_t bInNGSq = bIn * td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ();
    int64_t dMaxAlign = Align(BEST_BASIC_BLOCK_NUM / bInNGSq, FP16_BLOCK_ELES);

    uint32_t clcDInner = FP16_BLOCK_ELES;
    // 最大不能超过32k，所以需要减16B
    if (dMaxAlign > FP16_BLOCK_ELES) {
      clcDInner = dMaxAlign - FP16_BLOCK_ELES;
    }
    uint32_t dSize = (td_.opInfo.get_d() / clcDInner) + ((td_.opInfo.get_d() % clcDInner == 0) ? 0 : 1);
    uint32_t dInnerTail = td_.opInfo.get_d() - (dSize - 1) * clcDInner;
    uint32_t dInnerTailAlign = Align(dInnerTail * inputDTypeSize, BLOCK_BYTE) / inputDTypeSize;

    td_.singleCoreParams.set_clcDInner(clcDInner);
    td_.singleCoreParams.set_dSize(dSize);
    td_.singleCoreParams.set_dInnerTail(dInnerTail);
    td_.singleCoreParams.set_dInnerTailAlign(dInnerTailAlign);

    int64_t bOut = CeilCommon(td_.opInfo.get_b(), bCvRatioInner);
    td_.splitCoreParams.set_bOut(bOut);

    td_.singleCoreParams.set_singleCoreBatchRange(CeilCommon(td_.splitCoreParams.get_bOut(), aicoreParams_.blockDim));
    td_.splitCoreParams.set_usedCoreNum(
        CeilCommon(td_.splitCoreParams.get_bOut(), td_.singleCoreParams.get_singleCoreBatchRange()));

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus DoInCoreTiling() {
    td_.singleCoreParams.set_subRange(CeilCommon(td_.opInfo.get_sQ(), PER_SUB_RANGE));
    td_.singleCoreParams.set_subMask(SINGLE_VEC_INST_DATASIZE / td_.opInfo.get_vecCalcDTypeSize());
    td_.singleCoreParams.set_subMaskTail((td_.opInfo.get_sQ() % PER_SUB_RANGE) *
                                         (BLOCK_BYTE / td_.opInfo.get_vecCalcDTypeSize()));
    td_.singleCoreParams.set_sKVAlignBlockNum(td_.opInfo.get_sKVAlign() * FP32_BYTES_NUM / BLOCK_BYTE);

    int64_t rightPadding = td_.opInfo.get_sKVAlign() - td_.opInfo.get_sKV();
    int64_t dstStride = 0;

    // kernel是按f32计算，如果rightpadding大于8，padding超过32B，会有问题，因此需要额外处理
    if (td_.opInfo.get_sKVAlign() - td_.opInfo.get_sKV() >= FP32_BLOCK_ELES) {
      dstStride = 1;
      rightPadding = td_.opInfo.get_sKVAlign() - td_.opInfo.get_sKV() - FP32_BLOCK_ELES;
    }
    td_.singleCoreParams.set_dstStride(dstStride);
    td_.singleCoreParams.set_rightPadding(rightPadding);

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus DoMulsTiling() {
    OPS_LOG_D("FlashAttentionScoreGradUbngs1s2BbTiling", "Do muls tiling begin.");
    uint64_t allNumQuery = static_cast<uint64_t>(td_.opInfo.get_b()) * td_.opInfo.get_n()
            * td_.opInfo.get_g() * td_.opInfo.get_sQ() * td_.opInfo.get_d();
    uint64_t allNumKv = static_cast<uint64_t>(td_.opInfo.get_b()) * td_.opInfo.get_n()
            * td_.opInfo.get_sKV() * td_.opInfo.get_d();

    uint32_t used_core_num = td_.splitCoreParams.get_usedCoreNum();
    if (used_core_num == 0) {
      return ge::GRAPH_PARAM_INVALID;
    }
    
    uint32_t postUbBaseSize = (aicoreParams_.ubSize) / POST_COEX_NODE / BUFFER_NUM / BASE_LEN_256 * BASE_LEN_256;
    uint32_t qPostBaseNum = postUbBaseSize / FP16_BYTES_NUM;
    uint32_t qPostBlockTotal = allNumQuery;
    uint64_t qSizeAlign = (qPostBlockTotal + BASE_LEN_256 - 1) / WORKSPACE_ALIGN_SIZE * WORKSPACE_ALIGN_SIZE * FP16_BYTES_NUM;
    uint32_t qPostTailNumTmp = qPostBlockTotal % qPostBaseNum;
    uint32_t qPostTailNum = qPostTailNumTmp == 0 ? qPostBaseNum : qPostTailNumTmp;
    uint32_t qPostBlockOuterTotal = (qPostBlockTotal + qPostBaseNum - 1) / qPostBaseNum;
    uint32_t qPostBlockFactor = (qPostBlockOuterTotal + used_core_num - 1) / used_core_num;

    uint32_t kvPostBaseNum = postUbBaseSize / FP16_BYTES_NUM;
    uint32_t kvPostBlockTotal = allNumKv;
    uint64_t kvSizeAlign = (kvPostBlockTotal + WORKSPACE_ALIGN_SIZE - 1) / WORKSPACE_ALIGN_SIZE * WORKSPACE_ALIGN_SIZE * FP16_BYTES_NUM;
    uint32_t kvPostTailNumTmp = kvPostBlockTotal % kvPostBaseNum;
    uint32_t kvPostTailNum = kvPostTailNumTmp == 0 ? kvPostBaseNum : kvPostTailNumTmp;
    uint32_t kvPostBlockOuterTotal = (kvPostBlockTotal + kvPostBaseNum - 1) / kvPostBaseNum;
    uint32_t kvPostBlockFactor = (kvPostBlockOuterTotal + used_core_num - 1) / used_core_num;

    td_.postTilingData.set_coreNum(used_core_num);
    td_.postTilingData.set_scaleValue(td_.opInfo.get_scaleValue());
    td_.postTilingData.set_postUbBaseSize(postUbBaseSize);
    td_.postTilingData.set_qPostBlockFactor(qPostBlockFactor);
    td_.postTilingData.set_qPostBlockTotal(qPostBlockTotal);
    td_.postTilingData.set_qPostBaseNum(qPostBaseNum);
    td_.postTilingData.set_qPostTailNum(qPostTailNum);
    td_.postTilingData.set_qSizeAlign(qSizeAlign);

    td_.postTilingData.set_kvPostBlockFactor(kvPostBlockFactor);
    td_.postTilingData.set_kvPostBlockTotal(kvPostBlockTotal);
    td_.postTilingData.set_kvPostBaseNum(kvPostBaseNum);
    td_.postTilingData.set_kvPostTailNum(kvPostTailNum);
    td_.postTilingData.set_kvSizeAlign(kvSizeAlign);

    uint32_t allNumDropGm =
        td_.opInfo.get_b() * td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ() * td_.opInfo.get_sKVAlign();
    uint32_t allNumMulGm =
        td_.opInfo.get_b() * td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ() * td_.opInfo.get_sKVAlign();
    
    td_.opInfo.set_dropGmWorkspaceLen(CeilCommon(allNumDropGm * FP16_BYTES_NUM, WORKSPACE_ALIGN_SIZE) *
                                          WORKSPACE_ALIGN_SIZE);
    td_.opInfo.set_mulGmWorkspaceLen(CeilCommon(allNumMulGm * FP16_BYTES_NUM, WORKSPACE_ALIGN_SIZE) *
                                         WORKSPACE_ALIGN_SIZE);
    td_.opInfo.set_dqWorkspaceLen(CeilCommon(allNumQuery * FP32_BYTES_NUM, BLOCK_BYTE) * BLOCK_BYTE);
    td_.opInfo.set_dkWorkspaceLen(CeilCommon(allNumKv * FP32_BYTES_NUM, BLOCK_BYTE) * BLOCK_BYTE);
    td_.opInfo.set_dropGmWorkspaceLen(CeilCommon(allNumDropGm * FP16_BYTES_NUM, BLOCK_BYTE) * BLOCK_BYTE);
    td_.opInfo.set_mulGmWorkspaceLen(CeilCommon(allNumMulGm * FP16_BYTES_NUM, BLOCK_BYTE) * BLOCK_BYTE);

    td_.opInfo.set_dqWorkspaceLen(CeilCommon(allNumQuery * FP32_BYTES_NUM, WORKSPACE_ALIGN_SIZE) *
                                      WORKSPACE_ALIGN_SIZE);
    td_.opInfo.set_dkWorkspaceLen(CeilCommon(allNumKv * FP32_BYTES_NUM, WORKSPACE_ALIGN_SIZE) *
                                      WORKSPACE_ALIGN_SIZE);

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetMm1AndMm2Tiling(matmul_tiling::MatmulApiTiling& mm1AndMm2, int32_t l1SizeRemain, int64_t bIn,
                                     TCubeTiling& mm1AndMm2Tiling) {
    mm1AndMm2.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16,
                       false);
    mm1AndMm2.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16,
                       true);
    mm1AndMm2.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND,
                       matmul_tiling::DataType::DT_FLOAT16);
    mm1AndMm2.SetShape(td_.opInfo.get_sQ(), td_.opInfo.get_sKV(), td_.opInfo.get_d());
    mm1AndMm2.SetOrgShape(td_.opInfo.get_sQ(), td_.opInfo.get_sKV(), td_.opInfo.get_d());
    mm1AndMm2.SetALayout(td_.opInfo.get_b(), td_.opInfo.get_sQ(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                         td_.opInfo.get_d());
    mm1AndMm2.SetBLayout(td_.opInfo.get_b(), td_.opInfo.get_sKV(), td_.opInfo.get_n(), 1, td_.opInfo.get_d());
    mm1AndMm2.SetCLayout(td_.opInfo.get_b(), td_.opInfo.get_sQ(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                         td_.opInfo.get_sKV());

    uint32_t layout = td_.opInfo.get_inputLayout();
    if (layout == static_cast<uint32_t>(InputLayout::BSH) || layout == static_cast<uint32_t>(InputLayout::BSND)) {
      mm1AndMm2.SetBatchNum(td_.opInfo.get_n() * td_.opInfo.get_g());
    } else {
      // SBH BNSD
      mm1AndMm2.SetBatchNum(bIn * td_.singleCoreParams.get_bCvRatio() * td_.opInfo.get_n() * td_.opInfo.get_g());
    }
    if (mm1AndMm2.GetTiling(mm1AndMm2Tiling) != 0) {
      return ge::GRAPH_PARAM_INVALID;
    }

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetMm31Tiling(matmul_tiling::MatmulApiTiling& mm31, int32_t l1SizeRemain, int64_t bIn,
                                TCubeTiling& mm31Tiling) {
    mm31.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16,
                  false);
    mm31.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16,
                  false);
    mm31.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    mm31.SetShape(td_.opInfo.get_sQ(), td_.opInfo.get_d(), td_.opInfo.get_sKV());
    mm31.SetOrgShape(td_.opInfo.get_sQ(), td_.opInfo.get_d(), td_.opInfo.get_sKV());
    mm31.SetALayout(td_.opInfo.get_b(), td_.opInfo.get_sQ(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                    td_.opInfo.get_sKV());
    mm31.SetBLayout(td_.opInfo.get_b(), td_.opInfo.get_sKV(), td_.opInfo.get_n(), 1, td_.opInfo.get_d());
    mm31.SetCLayout(td_.opInfo.get_b(), td_.opInfo.get_sQ(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                    td_.opInfo.get_d());

    uint32_t layout = td_.opInfo.get_inputLayout();
    if (layout == static_cast<uint32_t>(InputLayout::BSH) || layout == static_cast<uint32_t>(InputLayout::BSND)) {
      mm31.SetBatchNum(td_.opInfo.get_n() * td_.opInfo.get_g());
    } else {
      // SBH BNSD
      mm31.SetBatchNum(bIn * td_.singleCoreParams.get_bCvRatio() * td_.opInfo.get_n() * td_.opInfo.get_g());
    }
    if (mm31.GetTiling(mm31Tiling) != 0) {
      return ge::GRAPH_PARAM_INVALID;
    }

    return ge::GRAPH_SUCCESS;
  }

  ge::graphStatus SetMm32AndMm4Tiling(matmul_tiling::MatmulApiTiling& mm32AndMm4, int32_t l1SizeRemain, int64_t bIn,
                                      TCubeTiling& mm32AndMm4Tiling) {
    mm32AndMm4.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND,
                        matmul_tiling::DataType::DT_FLOAT16, true);
    mm32AndMm4.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND,
                        matmul_tiling::DataType::DT_FLOAT16, false);
    mm32AndMm4.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND,
                        matmul_tiling::DataType::DT_FLOAT16);
    mm32AndMm4.SetShape(td_.opInfo.get_sKV(), td_.opInfo.get_d(), td_.opInfo.get_sQ());
    mm32AndMm4.SetOrgShape(td_.opInfo.get_sKV(), td_.opInfo.get_d(), td_.opInfo.get_sQ());
    mm32AndMm4.SetALayout(td_.opInfo.get_b(), td_.opInfo.get_sKV(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                          td_.opInfo.get_sQ());
    mm32AndMm4.SetBLayout(td_.opInfo.get_b(), td_.opInfo.get_sQ(), td_.opInfo.get_n(), td_.opInfo.get_g(),
                          td_.opInfo.get_d());
    mm32AndMm4.SetCLayout(td_.opInfo.get_b(), td_.opInfo.get_sKV(), td_.opInfo.get_n(), 1, td_.opInfo.get_d());

    uint32_t layout = td_.opInfo.get_inputLayout();
    if (layout == static_cast<uint32_t>(InputLayout::BSH) || layout == static_cast<uint32_t>(InputLayout::BSND)) {
      mm32AndMm4.SetBatchNum(td_.opInfo.get_n() * td_.opInfo.get_g());
    } else {
      // SBH BNSD
      mm32AndMm4.SetBatchNum(bIn * td_.singleCoreParams.get_bCvRatio() * td_.opInfo.get_n() * td_.opInfo.get_g());
    }
    if (mm32AndMm4.GetTiling(mm32AndMm4Tiling) != 0) {
      return ge::GRAPH_PARAM_INVALID;
    }
    return ge::GRAPH_SUCCESS;
  }

  // 4、计算高阶API的tiling
  ge::graphStatus DoLibApiTiling() override {
    // mm tiling
    OPS_LOG_D("FAG_SPLIT_B", "DoLibApiTiling begin.");
    int32_t l1SizeRemain = static_cast<int32_t>(aicoreParams_.l1Size);
    matmul_tiling::MatmulApiTiling mm1AndMm2;
    SetMm1AndMm2Tiling(mm1AndMm2, l1SizeRemain, td_.singleCoreParams.get_bIn(), td_.mm1AndMm2TilingData);

    matmul_tiling::MatmulApiTiling mm31;
    SetMm31Tiling(mm31, l1SizeRemain, td_.singleCoreParams.get_bIn(), td_.mm31TilingData);

    matmul_tiling::MatmulApiTiling mm32AndMm4;
    SetMm32AndMm4Tiling(mm32AndMm4, l1SizeRemain, td_.singleCoreParams.get_bIn(), td_.mm32AndMm4TilingData);

    // vector tiling
    auto softmaxShape =
        Shape({td_.singleCoreParams.get_bIn() * td_.opInfo.get_g() * td_.opInfo.get_n() * td_.opInfo.get_sQ(),
               td_.opInfo.get_sKVAlign()});
    uint32_t softmaxComputeSize = td_.splitCoreParams.get_apiClcQueueSize();
    SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxComputeSize, td_.softmaxTilingData);
    auto softmaxGradShape =
        Shape({td_.singleCoreParams.get_bIn() * td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ(),
               td_.singleCoreParams.get_clcDInner()});
    int64_t softmaxGradComputeSize = td_.splitCoreParams.get_apiClcQueueSize();
    SoftMaxGradTilingFunc(softmaxGradShape, td_.opInfo.get_vecCalcDTypeSize(), softmaxGradComputeSize,
                          td_.softmaxGradTilingData, true);
    return ge::GRAPH_SUCCESS;
  }

  // 6、计算Workspace 大小
  ge::graphStatus GetWorkspaceSize() override {
    /* OCT-TODO, atttention mask可能需要额外的workspace */
    uint32_t sysLen = WORK_SPACE_BASE_CAL;
    uint32_t syncLen = 3200;
    uint32_t mm1WorkspaceLen = td_.singleCoreParams.get_bIn() * td_.singleCoreParams.get_bCvRatio() *
                               td_.opInfo.get_n() * td_.opInfo.get_g() * td_.opInfo.get_sQ() *
                               td_.opInfo.get_sKVAlign() * td_.opInfo.get_vecCalcDTypeSize() *
                               td_.splitCoreParams.get_usedCoreNum();
    uint32_t mm2WorkspaceLen = mm1WorkspaceLen;
    uint32_t dqWorkspaceLen = td_.opInfo.get_dqWorkspaceLen();
    uint32_t dkWorkspaceLen = td_.opInfo.get_dkWorkspaceLen();
    uint32_t dropOutWorkspaceLen = td_.opInfo.get_dropoutWorkspaceLen();
    uint32_t dropGmWorkspaceLen = td_.opInfo.get_dropGmWorkspaceLen();
    uint32_t mulGmWorkspaceLen = td_.opInfo.get_mulGmWorkspaceLen();

    uint64_t workspaceOffsets = syncLen + dropOutWorkspaceLen + mm1WorkspaceLen + mm2WorkspaceLen;
    td_.postTilingData.set_dqWorkSpaceOffset(workspaceOffsets);

    workspaceOffsets = workspaceOffsets + td_.opInfo.get_dqWorkspaceLen();
    td_.postTilingData.set_dkWorkSpaceOffset(workspaceOffsets);
    
    workspaceOffsets = workspaceOffsets + td_.opInfo.get_dkWorkspaceLen();
    td_.postTilingData.set_dvWorkSpaceOffset(workspaceOffsets);

    workspaceSize_ = sysLen + syncLen + dropOutWorkspaceLen + mm1WorkspaceLen + mm2WorkspaceLen + dqWorkspaceLen +
                     dkWorkspaceLen + dropGmWorkspaceLen + mulGmWorkspaceLen;
    td_.opInfo.set_syncLen(syncLen);
    td_.opInfo.set_mm1WorkspaceLen(mm1WorkspaceLen);
    td_.opInfo.set_mm2WorkspaceLen(mm2WorkspaceLen);
    return ge::GRAPH_SUCCESS;
  }

  // 7、保存Tiling数据
  ge::graphStatus PostTiling() override {
    auto blockdim =
        CalcTschBlockDim(td_.splitCoreParams.get_usedCoreNum(), aicoreParams_.aicNum, aicoreParams_.blockDim);
    OPS_ERR_IF(blockdim == 0,
               OPS_REPORT_VECTOR_INNER_ERR("FlashAttentionScoreGradUbngs1s2BbTiling",
                                           "blockdim is 0, aicNum is %ld, aivNum is %ld.",
                                           aicoreParams_.aicNum, aicoreParams_.blockDim),
               return ge::GRAPH_FAILED);
    context_->SetBlockDim(blockdim);

    size_t* workspaces = context_->GetWorkspaceSizes(1);
    workspaces[0] = workspaceSize_;
    td_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(td_.GetDataSize());
    OPS_LOG_D("FAG split B", "tiling data size is %zu.", td_.GetDataSize());
    return ge::GRAPH_SUCCESS;
  }
};

REGISTER_TILING_TEMPLATE("FlashAttentionScoreGrad", FlashAttentionScoreGradUbngs1s2BbTiling, 10000);

}  // namespace optiling
